{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp models.TSiTPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TSiT\n",
    "\n",
    "> This is a PyTorch implementation created by Ignacio Oguiza (timeseriesAI@gmail.com) based on ViT (Vision Transformer)\n",
    "     \n",
    "Reference: \n",
    "\n",
    "     Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020).\n",
    "     An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.\n",
    "\n",
    "     This implementation is a modified version of Vision Transformer that is part of the great timm library\n",
    "     (https://github.com/rwightman/pytorch-image-models/blob/72b227dcf57c0c62291673b96bdc06576bb90457/timm/models/vision_transformer.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *\n",
    "from tsai.models.utils import *\n",
    "from tsai.models.layers import *\n",
    "from typing import Callable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class _TSiTEncoderLayer(nn.Module):\n",
    "    \"\"\"Transformer encoder block: multi-head self-attention followed by a position-wise feed-forward network.\n",
    "\n",
    "    Both sublayers use a residual connection, with the sublayer output passed through `DropPath` when\n",
    "    `drop_path_rate` is not zero. If `pre_norm` is True, LayerNorm is applied to each sublayer's input;\n",
    "    otherwise it is applied to the sum of the sublayer output and the residual.\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model:int, n_heads:int, attn_dropout:float=0., dropout:float=0, drop_path_rate:float=0., \n",
    "                 mlp_ratio:int=1, qkv_bias:bool=True, act:str='relu', pre_norm:bool=False):\n",
    "        super().__init__()\n",
    "        # Attention sublayer and its normalization\n",
    "        self.mha = MultiheadAttention(d_model, n_heads, attn_dropout=attn_dropout, proj_dropout=dropout, qkv_bias=qkv_bias)\n",
    "        self.attn_norm = nn.LayerNorm(d_model)\n",
    "        # Feed-forward sublayer and its normalization\n",
    "        self.pwff = PositionwiseFeedForward(d_model, dropout=dropout, act=act, mlp_ratio=mlp_ratio)\n",
    "        self.ff_norm = nn.LayerNorm(d_model)\n",
    "        # No DropPath module is created when the rate is zero\n",
    "        self.drop_path = nn.Identity() if drop_path_rate == 0 else DropPath(drop_path_rate)\n",
    "        self.pre_norm = pre_norm\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.pre_norm:\n",
    "            # norm -> sublayer -> drop path -> add residual\n",
    "            attn_out = self.mha(self.attn_norm(x))[0]\n",
    "            x = self.drop_path(attn_out) + x\n",
    "            ff_out = self.pwff(self.ff_norm(x))\n",
    "            return self.drop_path(ff_out) + x\n",
    "        # sublayer -> drop path -> add residual -> norm\n",
    "        attn_out = self.mha(x)[0]\n",
    "        x = self.attn_norm(self.drop_path(attn_out) + x)\n",
    "        ff_out = self.pwff(x)\n",
    "        return self.ff_norm(self.drop_path(ff_out) + x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class _TSiTEncoder(nn.Module):\n",
    "    \"\"\"Stack of `depth` `_TSiTEncoderLayer` blocks whose drop path rate grows linearly with depth.\n",
    "\n",
    "    A final LayerNorm is applied only when `pre_norm` is True.\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model, n_heads, depth:int=6, attn_dropout:float=0., dropout:float=0, drop_path_rate:float=0., \n",
    "                 mlp_ratio:int=1, qkv_bias:bool=True, act:str='relu', pre_norm:bool=False):\n",
    "        super().__init__()\n",
    "        # One rate per layer, evenly spaced from 0 (first layer) to drop_path_rate (last layer)\n",
    "        layer_rates = torch.linspace(0, drop_path_rate, depth).tolist()\n",
    "        blocks = [_TSiTEncoderLayer(d_model, n_heads, attn_dropout=attn_dropout, dropout=dropout, drop_path_rate=rate,\n",
    "                                    mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, act=act, pre_norm=pre_norm)\n",
    "                  for rate in layer_rates]\n",
    "        self.encoder = nn.Sequential(*blocks)\n",
    "        self.norm = nn.LayerNorm(d_model) if pre_norm else nn.Identity()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.norm(self.encoder(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class _TSiTBackbone(Module):\n",
    "    \"\"\"TSiT backbone: categorical embeddings (optional) -> feature extractor -> position embedding / class token -> encoder.\n",
    "\n",
    "    Args:\n",
    "        c_in:              number of input variables (channels).\n",
    "        seq_len:           number of input time steps.\n",
    "        d_head, d_ff:      accepted but not used by this backbone.\n",
    "        feature_extractor: an nn.Module used as is, or a callable invoked as feature_extractor(c_in, d_model) to build one.\n",
    "                           Its output is added to a position embedding of width d_model and fed to the encoder,\n",
    "                           so it must produce d_model channels.\n",
    "        The remaining arguments control the embeddings / token or are forwarded to `_TSiTEncoder` (see `TSiTPlus`).\n",
    "\n",
    "    Input:  bs x c_in x seq_len\n",
    "    Output: bs x d_model x n_steps (n_steps = length of the feature extractor output, +1 if use_token is True)\n",
    "\n",
    "    NOTE(review): `Module` (rather than nn.Module) presumably comes from fastai and calls super().__init__()\n",
    "    automatically, which would explain why __init__ does not call it - confirm.\n",
    "    \"\"\"\n",
    "    def __init__(self, c_in:int, seq_len:int, depth:int=6, d_model:int=128, n_heads:int=16, d_head:Optional[int]=None, act:str='relu', d_ff:int=256, \n",
    "                 qkv_bias:bool=True, attn_dropout:float=0., dropout:float=0., drop_path_rate:float=0., mlp_ratio:int=1, \n",
    "                 pre_norm:bool=False, use_token:bool=True,  use_pe:bool=True, n_embeds:Optional[list]=None, embed_dims:Optional[list]=None, \n",
    "                 cat_pos:Optional[list]=None, feature_extractor:Optional[Callable]=None):\n",
    "\n",
    "        # Categorical embeddings\n",
    "        if n_embeds is not None:\n",
    "            if embed_dims is None:  \n",
    "                embed_dims = [emb_sz_rule(s) for s in n_embeds]\n",
    "            self.to_cat_embed = MultiEmbedding(c_in, n_embeds, embed_dims=embed_dims, cat_pos=cat_pos)\n",
    "            # each categorical channel is replaced by its embedding channels\n",
    "            c_in = c_in + sum(embed_dims) - len(n_embeds)\n",
    "        else:\n",
    "            self.to_cat_embed = nn.Identity()\n",
    "\n",
    "        # Feature extractor\n",
    "        if feature_extractor:\n",
    "            if isinstance(feature_extractor, nn.Module):  self.feature_extractor = feature_extractor\n",
    "            else: self.feature_extractor = feature_extractor(c_in, d_model)\n",
    "            # the extractor may change the sequence length: the position embedding below uses the new one\n",
    "            c_in, seq_len = output_size_calculator(self.feature_extractor, c_in, seq_len)\n",
    "        else:\n",
    "            self.feature_extractor = nn.Conv1d(c_in, d_model, 1)\n",
    "        self.transpose = Transpose(1,2)\n",
    "\n",
    "        # Position embedding & token\n",
    "        if use_pe:\n",
    "            self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model))\n",
    "        self.use_pe = use_pe\n",
    "        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))\n",
    "        self.use_token = use_token\n",
    "        self.emb_dropout = nn.Dropout(dropout)\n",
    "\n",
    "        # Encoder (attn_dropout must be forwarded explicitly, otherwise the encoder default of 0. is used)\n",
    "        self.encoder = _TSiTEncoder(d_model, n_heads, depth=depth, qkv_bias=qkv_bias, attn_dropout=attn_dropout, dropout=dropout,\n",
    "                                    mlp_ratio=mlp_ratio, drop_path_rate=drop_path_rate, act=act, pre_norm=pre_norm)\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        # Categorical embeddings\n",
    "        x = self.to_cat_embed(x)\n",
    "\n",
    "        # Feature extractor\n",
    "        x = self.feature_extractor(x)\n",
    "\n",
    "        # Position embedding & token\n",
    "        x = self.transpose(x)\n",
    "        if self.use_pe: \n",
    "            x = x + self.pos_embed\n",
    "        if self.use_token: # token is concatenated after position embedding so that embedding can be learned using self-supervised learning\n",
    "            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n",
    "        x = self.emb_dropout(x)\n",
    "\n",
    "        # Encoder\n",
    "        x = self.encoder(x)\n",
    "\n",
    "        # Output: swap the last two dims back so that channels come before steps\n",
    "        x = x.transpose(1,2)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#exports\n",
    "class TSiTPlus(nn.Sequential):\n",
    "    r\"\"\"Time series transformer model based on ViT (Vision Transformer):\n",
    "\n",
    "    Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020).\n",
    "    An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.\n",
    "\n",
    "    This implementation is a modified version of Vision Transformer that is part of the great timm library\n",
    "    (https://github.com/rwightman/pytorch-image-models/blob/72b227dcf57c0c62291673b96bdc06576bb90457/timm/models/vision_transformer.py)\n",
    "\n",
    "    Args:\n",
    "        c_in:               the number of features (aka variables, dimensions, channels) in the time series dataset.\n",
    "        c_out:              the number of target classes.\n",
    "        seq_len:            number of time steps in the time series.\n",
    "        d_model:            total dimension of the model (number of features created by the model).\n",
    "        depth:              number of blocks in the encoder.\n",
    "        n_heads:            parallel attention heads. Default:16 (range(8-16)).\n",
    "        d_head:             size of the learned linear projection of queries, keys and values in the MHA. \n",
    "                            Currently not used: the backbone builds the attention layers from d_model and n_heads only.\n",
    "        act:                the activation function of positionwise feedforward layer.\n",
    "        d_ff:               the dimension of the feedforward network model. \n",
    "                            Currently not used: the feedforward size is set through mlp_ratio.\n",
    "        attn_dropout:       dropout rate applied to the attention sublayer.\n",
    "        dropout:            dropout applied to the embedded sequence steps after position embeddings have been added and \n",
    "                            to the attention projection and mlp sublayer in the encoder.\n",
    "        drop_path_rate:     stochastic depth rate.\n",
    "        mlp_ratio:          ratio of mlp hidden dim to embedding dim.\n",
    "        qkv_bias:           determines whether bias is applied to the Linear projections of queries, keys and values in the MultiheadAttention\n",
    "        pre_norm:           if True normalization will be applied as the first step in the sublayers. Defaults to False.\n",
    "        use_token:          if True, the output will come from the transformed token. This is meant to be used in classification tasks.\n",
    "                            It is automatically set to False when c_out == 1.\n",
    "        use_pe:             flag to indicate if positional embedding is used.\n",
    "        n_embeds:           list with the sizes of the dictionaries of embeddings (int).\n",
    "        embed_dims:         list with the sizes of each embedding vector (int).\n",
    "        cat_pos:            list with the position of the categorical variables in the input.\n",
    "        feature_extractor:  an nn.Module or optional callable that will be used to preprocess the time series before \n",
    "                            the embedding step. It is useful to extract features or resample the time series.\n",
    "        flatten:            flag to indicate if the 3d logits will be flattened to 2d in the model's head if use_token is set to False. \n",
    "                            If use_token is False and flatten is False, the model will apply a pooling layer.\n",
    "        concat_pool:        if True the pooling head starts with a GACP1d layer (concat pooling, doubles the number of features); \n",
    "                            otherwise, with a GAP1d layer (average pooling).\n",
    "        fc_dropout:         dropout applied to the final fully connected layer.\n",
    "        use_bn:             flag that indicates if batchnorm will be applied to the head.\n",
    "        bias_init:          value(s) used to initialize the bias of the output layer (a scalar, or one value per output).\n",
    "        y_range:            range of possible y values (used in regression tasks).        \n",
    "        custom_head:        custom head that will be applied to the network. It must contain all kwargs (pass a partial function)\n",
    "        verbose:            flag to control verbosity of the model.\n",
    "\n",
    "    Input:\n",
    "        x: bs (batch size) x nvars (aka features, variables, dimensions, channels) x seq_len (aka time steps)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, c_in:int, c_out:int, seq_len:int, d_model:int=128, depth:int=6, n_heads:int=16, d_head:Optional[int]=None, act:str='relu',\n",
    "                 d_ff:int=256, attn_dropout:float=0., dropout:float=0., drop_path_rate:float=0., mlp_ratio:int=1, qkv_bias:bool=True, pre_norm:bool=False, \n",
    "                 use_token:bool=True, use_pe:bool=True, n_embeds:Optional[list]=None, embed_dims:Optional[list]=None, cat_pos:Optional[list]=None, \n",
    "                 feature_extractor:Optional[Callable]=None, flatten:bool=False, concat_pool:bool=True, fc_dropout:float=0., use_bn:bool=False, \n",
    "                 bias_init:Optional[Union[float, list]]=None, y_range:Optional[tuple]=None, custom_head:Optional[Callable]=None, verbose:bool=True):\n",
    "\n",
    "        # The class token is not used for single-output models: the flatten / pooling head is used instead\n",
    "        if use_token and c_out == 1: \n",
    "            use_token = False\n",
    "            pv(\"use_token set to False as c_out == 1\", verbose)\n",
    "        # qkv_bias is forwarded explicitly, otherwise the backbone default (True) would always be used\n",
    "        backbone = _TSiTBackbone(c_in, seq_len, depth=depth, d_model=d_model, n_heads=n_heads, d_head=d_head, act=act,\n",
    "                                 d_ff=d_ff, qkv_bias=qkv_bias, attn_dropout=attn_dropout, dropout=dropout, \n",
    "                                 drop_path_rate=drop_path_rate, pre_norm=pre_norm, mlp_ratio=mlp_ratio, use_pe=use_pe, \n",
    "                                 use_token=use_token, n_embeds=n_embeds, embed_dims=embed_dims, cat_pos=cat_pos, \n",
    "                                 feature_extractor=feature_extractor)\n",
    "\n",
    "        # Plain attributes are set here; nn.Sequential.__init__ is called at the end, once both submodules exist\n",
    "        self.head_nf = d_model\n",
    "        self.c_out = c_out\n",
    "        self.seq_len = seq_len\n",
    "\n",
    "        # Head\n",
    "        if custom_head:\n",
    "            if isinstance(custom_head, nn.Module): head = custom_head\n",
    "            else: head = custom_head(self.head_nf, c_out, seq_len)\n",
    "        else:\n",
    "            nf = d_model\n",
    "            layers = []\n",
    "            if use_token: \n",
    "                layers += [TokenLayer()]\n",
    "            elif flatten:\n",
    "                layers += [Reshape(-1)]\n",
    "                # NOTE(review): this uses the input seq_len; if feature_extractor changes the sequence length\n",
    "                # the linear layer size will presumably not match the backbone output - confirm\n",
    "                nf = nf * seq_len\n",
    "            else:\n",
    "                if concat_pool: nf *= 2\n",
    "                layers += [GACP1d(1) if concat_pool else GAP1d(1)]\n",
    "            if use_bn: layers += [nn.BatchNorm1d(nf)]\n",
    "            if fc_dropout: layers += [nn.Dropout(fc_dropout)]\n",
    "\n",
    "            # Last layer\n",
    "            linear = nn.Linear(nf, c_out)\n",
    "            if bias_init is not None: \n",
    "                # scalars (int or float) fill the existing bias; anything else replaces it with the given values\n",
    "                if isinstance(bias_init, (int, float)): nn.init.constant_(linear.bias, bias_init)\n",
    "                else: linear.bias = nn.Parameter(torch.as_tensor(bias_init, dtype=torch.float32))\n",
    "            layers += [linear]\n",
    "\n",
    "            if y_range: layers += [SigmoidRange(*y_range)]\n",
    "            head = nn.Sequential(*layers)\n",
    "        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
    "        \n",
    "        \n",
    "TSiT = TSiTPlus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: the model returns (bs, c_out) outputs both with the class token head (use_token=True)\n",
    "# and with the pooling head (use_token=False)\n",
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 50\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "model = TSiTPlus(nvars, c_out, seq_len, attn_dropout=.1, dropout=.1, use_token=True)\n",
    "test_eq(model(xb).shape, (bs, c_out))\n",
    "model = TSiTPlus(nvars, c_out, seq_len, attn_dropout=.1, dropout=.1, use_token=False)\n",
    "test_eq(model(xb).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bias_init passed as an array (one value per class): it must end up as the bias of the last linear layer\n",
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 50\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "bias_init = np.array([0.8, .2])\n",
    "model = TSiTPlus(nvars, c_out, seq_len, bias_init=bias_init)\n",
    "test_eq(model(xb).shape, (bs, c_out))\n",
    "test_eq(model.head[1].bias.data, tensor(bias_init))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "use_token set to False as c_out == 1\n"
     ]
    }
   ],
   "source": [
    "# Single output (c_out == 1): use_token is switched off automatically (see the printed message)\n",
    "# and a scalar bias_init fills the bias of the last linear layer\n",
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 50\n",
    "c_out = 1\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "bias_init = 8.5\n",
    "model = TSiTPlus(nvars, c_out, seq_len, bias_init=bias_init)\n",
    "test_eq(model(xb).shape, (bs, c_out))\n",
    "test_eq(model.head[1].bias.data, tensor([bias_init]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature extractor\n",
    "\n",
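    "The memory and compute cost of self-attention grows quadratically with the sequence length, so transformers are hard to apply directly to long sequences. To work around this, `TSiTPlus` accepts a feature extractor that can subsample the sequence to generate a more manageable input."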
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABAYAAABKCAYAAAAoj1bdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPBklEQVR4nO3dfXBVdX7H8c8nCfJQFI1EEKKNBReDWQMGQV0XUXeouj5UXXwCH7bL2nG2u1YcwbZbtchOnc5Ox526nfEZZrBUK2jVsVW7IvgwpQprVhQQluKCJhI2ICKy5JJv/7gnnTSbhCTc5JB7368Zxnvv+d3v73PwzCX3m/M7xxEhAAAAAABQmIrSDgAAAAAAANJDYwAAAAAAgAJGYwAAAAAAgAJGYwAAAAAAgAJGYwAAAAAAgAJGYwAAAAAAgAJGYwAA0G/Zft327OTxTNuvHEKtCtthuyR5/u+2b8pRzm/a3tDq+Rbb38pF7aTeB7an5aoeAAAoLDQGAACpsn2O7bdtf2670fZbts/obp2IeDIipreqG7bH9jRXRFwUEYsONq4r80TEGxExrqdZ2sy30PaCNvVPjYjXc1EfAAAUnpK0AwAACpftoyS9KOlWSU9LOkLSNyX9Ls1cuWS7JCIyaecAAADoCGcMAADS9DVJioglEXEgIr6KiFci4leSZPvm5AyCB5MzCtbbvqC9QsnYN5PHK5OXa23vsX1NO+OLbf/U9g7bmyV9u8321ssUxtpekWTYYfupjuaxPc32NtvzbNdLeqLltTYRzrD9oe2dtp+wPajtfrTKEkmGWyTNlDQ3me+FZPv/LU2wPdD2A7Y/Tf48YHtgsq0l2x22t9uus/3dg/5fAgAAeY3GAAAgTR9JOmB7ke2LbB/Tzpgpkn4tabikeyQts13aWdGImJo8rI6IoRHxVDvDvi/pEkkTJU2S9J1OSt4n6RVJx0gql/SPB5lnpKRSSX8o6ZYOas6U9MeSxijbIPlxZ/uUzPewpCcl/X0y36XtDPtrSWdKmiCpWtLkNrVHShomabSk70n6eQd/7wAAoEDQGAAApCYidks6R1JIekRSg+3nbY9oNWy7pAcioin54r1BbX6730NXJ3W3RkSjpL/rZGyTsl/yR0XEvoh4s5OxktQs6Z6I+F1EfNXBmAdbzf0TSdd1dwc6MFPS/IjYHhENkv5W0g2ttjcl25si4iVJeyTl5PoHAACgf6IxAABIVUSsi4ibI6JcUpWkUZIeaDXkk4iIVs8/TsYcqlGStrap25G5kizpv5M7APzpQWo3RMS+g4xpO3cu9klJndb70rb2b9tc82CvpKE5mhsAAPRDNAYAAIeNiFgvaaGyDYIWo2271fMTJX2ag+nqJJ3Qpm5Hueoj4vsRMUrSn0n6p4PciSA62dai7dwt+/SlpCEtG2yP7GbtT5U9u6G92gAAAL+HxgAAIDW2T0kuhFeePD9B2VPq/6vVsOMk/cj2ANszJFVKeqkL5T+T9EedbH86qVuerLG/q5OcM1oyStqp7Jfz5i7O05EfJHOXKntdgJbrE9RKOtX2hOSChPe2ed/B5lsi6ce2y2wPl3S3pMU9yAcAAAoEjQEAQJq+UPbigqtsf6lsQ2CtpDtajVkl6WRJO5Rdi/+diPhtF2rfK2mR7V22r25n+yOSXlb2i/gaScs6qXVGknGPpOcl3RYRm7s4T0f+WdkLGm5W9uKKCyQpIj6SNF/Sf0raKKnt9QwekzQ+me+5duoukPSupF9Jej/ZtwXdyAUAAAqM//+yTQAADh+2b5Y0OyLOSTsLAABAvuKMAQAAAAAAChiNAQAAAAAAChhLCQAAAAAAKGCcMQAAAAAAQAGjMQAAAAAAQAEr6Y2i9vCQKnqjNIA8MqRyXdoR8t7edZVpRwBygs+L3sfnBYCuW70jIsrSToHc6ZXGQLYp8G7vlAaQN05ZXJN2hLy3pobPYuQHPi96H58XALrOH6edALnFUgIA
AAAAAAoYjQEAAAAAAAoYjQEAAAAAAApYL11jAAAAAACAw9fq1auPKykpeVRSlfL7l+bNktZmMpnZNTU129sbQGMAAAAAAFBwSkpKHh05cmRlWVnZzqKiokg7T29pbm52Q0PD+Pr6+kclXdbemHzuigAAAAAA0JGqsrKy3fncFJCkoqKiKCsr+1zZMyPaH9OHeQAAAAAAOFwU5XtToEWynx1+/2cpAQAAAAAAfay+vr542rRp4yRpx44dA4qKiqK0tDQjSe+99966QYMGddi0WLly5ZDHH3/82IULF27NRZaDNgZsPy7pEknbI6LDUw8AAAAAAOivbNXksl6EVne2feTIkQfWr1//oSTNmTNn1NChQw/Mnz//s5btTU1NGjBgQLvvnTp16t6pU6fuzVXWriwlWCjpwlxNCAAAAAAAft9VV11Vcf3115942mmnnXLrrbeWL1++fMiECRNOqaysHD9x4sRTamtrB0rSiy++eOR55503Vso2FWbMmFExefLkceXl5V9fsGDBcd2d96BnDETEStsV3d4jAAAAAADQLXV1dUesWbNmfUlJiRobG4veeeed9QMGDNBzzz135Ny5c8tffvnlX7d9z6ZNmwa9/fbbG3bt2lVcWVlZdeeddzYMHDiwy9dPyNk1BmzfIumW7LMTc1UWAAAAAICCceWVV+4sKcl+VW9sbCy+5pprTtqyZcsg29HU1OT23jN9+vRdgwcPjsGDB2dKS0ubtm3bVjJmzJimrs6Zs7sSRMTDETEpIiZJZbkqCwAAAABAwRg6dGhzy+N58+aNPvfcc7/YuHHjBy+88MKm/fv3t/sdvvXZAcXFxcpkMu02EDrC7QoBAAAAADgM7d69u7i8vHy/JD300EPDe2seGgMAAAAAAByG5s2bV3/vvfeWV1ZWjs9kMr02jyM6vx6B7SWSpkkaLukzSfdExGOdv2dSSO/mKiOAPHX66pzeEQbtWFPT6V1ygH6Dz4vex+cFgK7z6uwS8v6ttrZ2S3V19Y60c/SV2tra4dXV1RXtbevKXQmuy3kiAAAAAABwWGApAQAAAAAABYzGAAAAAAAABYzGAAAAAAAABYzGAAAAAAAABYzGAAAAAAAABYzGAAAAAAAAfWzKlClfW7p06VGtX5s/f/5xM2fOPLG98ZMnTx63cuXKIZJ07rnnjt2xY0dx2zFz5swZdffdd4/obpaD3q4QAAAAAIB8V7OmpiaX9Vafvnp1Z9tnzJjRuGTJktKrrrpqd8trS5cuLb3//vu3Haz2ihUrNuUiYwvOGAAAAAAAoI/dcMMNO1977bVh+/btsyRt2LDhiO3btw9YvHhxaVVVVeXYsWNPvf3220e1997Ro0d/va6urkSS5s2bN7KioqKqpqZm3MaNGwf2JEsvnTGweo/kDb1TG+gzwyXtSDtEPluT054s2meJYxl5YE0Nx3Hvc9oBCgXHMvLBuLQD5IMRI0YcqK6u/vKZZ54ZNmvWrF2LFi0qvfTSS3fed999dSNGjDiQyWR09tlnj1u1atXgKVOmfNVejTfeeGPIs88+W/r+++9/2NTUpAkTJoyfOHHi3u5m6a2lBBsiYlIv1Qb6hO13OY6RDziWkQ84jpEvOJaRD2y/m3aGfHH11Vc3PvXUU8fMmjVr17Jly0ofeeSRLYsWLSpduHDh8Ewm44aGhgG1tbWDOmoMLF++fOjFF1+868gjj2yWpOnTp+/qSQ6WEgAAAAAAkILrr79+11tvvXXUm2++OWTfvn1FZWVlmQcffHDEihUrPvroo48+PP/88z/ft29fr39vpzEAAAAAAEAKhg0b1nzWWWd9MXv27IorrriicefOncWDBw9uLi0tPbB169aS119/fVhn7z///PP3vPTSS0fv2bPHO3fuLHr11VeP7kmO3lpK8HAv1QX6Escx8gXHMvIBxzHyBccy8gHHcQ5de+21jTfeeOOYJUuWbJ44ceK+qqqqvWPGjKk6/vjj99fU1Ozp7L3nnHPO3iuu
uKKxqqrq1GOPPbbptNNO+7InGRwRPUsPAAAAAEA/VVtbu6W6urpgLgZaW1s7vLq6uqK9bSwlAAAAAACggOW0MWD7QtsbbG+yfVcuawN9xfYJtpfb/tD2B7ZvSzsT0FO2i23/0vaLaWcBesr20bafsb3e9jrbZ6WdCegu27cnP1estb3E9qC0MwFdYftx29ttr231WqntV21vTP57TJoZcehy1hiwXSzp55IukjRe0nW2x+eqPtCHMpLuiIjxks6U9AOOZfRjt0lal3YI4BD9TNJ/RMQpkqrFMY1+xvZoST+SNCkiqiQVS7o23VRAly2UdGGb1+6S9IuIOFnSL5Ln6MdyecbAZEmbImJzROyX9C+SLs9hfaBPRERdRKxJHn+h7A+go9NNBXSf7XJJ35b0aNpZgJ6yPUzSVEmPSVJE7I+IXamGAnqmRNJg2yWShkj6NOU8QJdExEpJjW1evlzSouTxIkl/0peZcqi5ubnZaYfoC8l+Nne0PZeNgdGStrZ6vk18mUI/Z7tC0kRJq1KOAvTEA5LmqpN/BIB+4CRJDZKeSJbFPGr7D9IOBXRHRHwi6aeSfiOpTtLnEfFKuqmAQzIiIuqSx/WSRqQZ5hCsbWhoGJbvzYHm5mY3NDQMk7S2ozG9dbtCoN+zPVTSUkl/ERG7084DdIftSyRtj4jVtqelHAc4FCWSTpf0w4hYZftnyp6y+jfpxgK6Lll/fbmyja5dkv7V9qyIWJxqMCAHIiJs98tb3WUymdn19fWP1tfXVym/L8zfLGltJpOZ3dGAXDYGPpF0Qqvn5clrQL9je4CyTYEnI2JZ2nmAHviGpMtsXyxpkKSjbC+OiFkp5wK6a5ukbRHRcubWM2ItK/qfb0n6n4hokCTbyySdLYnGAPqrz2wfHxF1to+XtD3tQD1RU1OzXdJlaec4HOSyK/KOpJNtn2T7CGUvqPJ8DusDfcK2lV3Lui4i/iHtPEBPRMRfRkR5RFQo+3n8Gk0B9EcRUS9pq+1xyUsXSPowxUhAT/xG0pm2hyQ/Z1wgLqKJ/u15STclj2+S9G8pZkEO5OyMgYjI2P5zSS8re6XVxyPig1zVB/rQNyTdIOl92+8lr/1VRLyUXiQAKGg/lPRk8ouHzZK+m3IeoFuSZTDPSFqj7N2Pfinp4XRTAV1je4mkaZKG294m6R5J90t62vb3JH0s6er0EiIXHNEvl4MAAAAAAIAcyOcLLAAAAAAAgIOgMQAAAAAAQAGjMQAAAAAAQAGjMQAAAAAAQAGjMQAAAAAAQAGjMQAAAAAAQAGjMQAAAAAAQAGjMQAAAAAAQAH7X+lgXa6cGvwmAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 1152x36 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TSTensor(samples:8, vars:3, len:5000, device=cpu)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tsai.data.validation import get_splits\n",
    "from tsai.data.core import get_ts_dls\n",
    "# Dummy long-sequence dataset: 10 all-zero samples, 3 variables, 5000 time steps, random binary labels\n",
    "X = np.zeros((10, 3, 5000)) \n",
    "y = np.random.randint(0,2,X.shape[0])\n",
    "splits = get_splits(y)\n",
    "dls = get_ts_dls(X, y, splits=splits)\n",
    "xb, yb = dls.train.one_batch()\n",
    "xb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you try to use TSiTPlus, it's likely you'll get an 'out-of-memory' error.\n",
    "\n",
    "To avoid this you can subsample the sequence reducing the input's length. This can be done in multiple ways. Here are a few examples: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 3, 99])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Separable convolution (to avoid mixing channels)\n",
    "# Depthwise convolution (groups == number of channels, so channels are not mixed).\n",
    "# With ks=100, stride=50 and no padding, 5000 steps are reduced to (5000 - 100) // 50 + 1 = 99\n",
    "feature_extractor = Conv1d(xb.shape[1], xb.shape[1], ks=100, stride=50, padding=0, groups=xb.shape[1]).to(default_device())\n",
    "feature_extractor.to(xb.device)(xb).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convolution (if you want to mix channels or change number of channels)\n",
    "feature_extractor=MultiConv1d(xb.shape[1], 64, kss=[1,3,5,7,9], keep_original=True).to(default_device())\n",
    "test_eq(feature_extractor.to(xb.device)(xb).shape, (xb.shape[0], 64, xb.shape[-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 3, 100])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# MaxPool\n",
    "feature_extractor = nn.Sequential(Pad1d((0, 50), 0), nn.MaxPool1d(kernel_size=100, stride=50)).to(default_device())\n",
    "feature_extractor.to(xb.device)(xb).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 3, 100])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# AvgPool\n",
    "feature_extractor = nn.Sequential(Pad1d((0, 50), 0), nn.AvgPool1d(kernel_size=100, stride=50)).to(default_device())\n",
    "feature_extractor.to(xb.device)(xb).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you decide what type of transform you want to apply, you just need to pass the layer as the `feature_extractor` argument:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 1000\n",
    "c_out = 2\n",
    "d_model = 128\n",
    "\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "# Passing a partial (instead of an instantiated layer) lets the model build the extractor itself by calling it\n",
    "# with (c_in, d_model), so the extractor output already has d_model channels\n",
    "feature_extractor = partial(Conv1d, ks=5, stride=3, padding=0, groups=xb.shape[1])\n",
    "model = TSiTPlus(nvars, c_out, seq_len, d_model=d_model, feature_extractor=feature_extractor)\n",
    "test_eq(model.to(xb.device)(xb).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Categorical variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.utils import alphabet, ALPHABET\n",
    "# Toy input with 3 variables and 10 steps: 1 continuous variable (c) and 2 label-encoded categorical variables (a, b)\n",
    "a = alphabet[np.random.randint(0,3,40)]\n",
    "b = ALPHABET[np.random.randint(6,10,40)]\n",
    "c = np.random.rand(40).reshape(4,1,10)\n",
    "map_a = {k:v for v,k in enumerate(np.unique(a))}\n",
    "map_b = {k:v for v,k in enumerate(np.unique(b))}\n",
    "n_embeds = [len(m.keys()) for m in [map_a, map_b]]\n",
    "szs = [emb_sz_rule(n) for n in n_embeds]\n",
    "a = np.asarray(a.map(map_a)).reshape(4,1,10)\n",
    "b = np.asarray(b.map(map_b)).reshape(4,1,10)\n",
    "inp = torch.from_numpy(np.concatenate((c,a,b), 1)).float()\n",
    "feature_extractor = partial(Conv1d, ks=3, padding='same')\n",
    "# NOTE(review): n_embeds (and szs) are computed above but never passed to TSiTPlus, so this model is built\n",
    "# without categorical embeddings (cat_pos is only used when n_embeds is given) - confirm whether\n",
    "# n_embeds=n_embeds was intended here\n",
    "model = TSiTPlus(3, 2, 10, d_model=64, cat_pos=[1,2], feature_extractor=feature_extractor)\n",
    "test_eq(model(inp).shape, (4,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "IPython.notebook.save_checkpoint();"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "124_models.TSiTPlus.ipynb saved at 2021-11-30 11:54:16.\n",
      "Converted 124_models.TSiTPlus.ipynb.\n",
      "\n",
      "\n",
      "Correct conversion! 😃\n",
      "Total time elapsed 0.089 s\n",
      "Tuesday 30/11/21 11:54:20 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide\n",
    "from tsai.imports import create_scripts\n",
    "from tsai.export import get_nb_name\n",
    "nb_name = get_nb_name()\n",
    "create_scripts(nb_name);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
